
# Code from the source dataset:
# ---------------------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf8 -*
#####################
# read_td_events.py #
#####################
# Feb 2017 - Jean-Matthieu Maro
# Email: jean-matthieu dot maro, hosted at inserm, which is located in FRance.
# Thanks to Germain Haessig and Laurent Dardelet.

from struct import unpack, pack
import numpy as np
import sys


def peek(f, length=1):
    """Return the next ``length`` bytes of binary file ``f`` without consuming them.

    The file position is restored to where it was before the call.
    """
    pos = f.tell()
    data = f.read(length)
    f.seek(pos)
    return data

def readATIS_tddat(file_name, orig_at_zero = True, drop_negative_dt = True, verbose = True, events_restriction = (0, np.inf)):

    """
    reads ATIS td events in .dat format (as generated by kAER)

    Each event is one 64-bit word, decoded in native byte order (like struct's 'Q'):
    bits 0-31 timestamp, bits 32-40 x, bits 41-48 y, bit 49 polarity.
    An optional header made of lines starting with '%' may precede the events; when
    present it is followed by one byte of event type and one byte of event size.

    input:
    file_name: string, path to the .dat file
    orig_at_zero: bool, if True, timestamps are shifted so that the first returned event is at 0
    drop_negative_dt: bool, if True, events whose timestamp is smaller than a timestamp seen
        before them are dismissed, so that the returned timestamps are non-decreasing
    verbose: bool, if True, verbose mode.
    events_restriction: sequence [first index, last index], only events whose position in the
        file lies in this inclusive range are read (these are event indices, not timestamps)

    output:
    timestamps: numpy int64 array of length (number of events), timestamps
    coords: numpy int64 array of size (number of events, 2), spatial coordinates: col 0 is x, col 1 is y.
    polarities: numpy int64 array of length (number of events), polarities
    removed_events: integer, number of removed events (negative delta-ts), -1 if drop_negative_dt is False

    If the header declares an event size other than 8 bytes, a message is printed and
    (-1, -1, -1, -1) is returned.
    """

    tsmask = np.uint64(0x00000000FFFFFFFF)
    polmask = np.uint64(0x0002000000000000)
    xmask = np.uint64(0x000001FF00000000)
    ymask = np.uint64(0x0001FE0000000000)
    polpadding = np.uint64(49)
    ypadding = np.uint64(41)
    xpadding = np.uint64(32)

    if verbose:
        print('Reading _td dat file... (' + file_name + ')')

    # `with` guarantees the file is closed on every path, including the early return below.
    with open(file_name, 'rb') as file:
        header = False
        while peek(file) == b'%':
            file.readline()
            header = True
        if header:
            ev_type = unpack('B', file.read(1))[0]
            ev_size = unpack('B', file.read(1))[0]
            if verbose:
                print('> Header exists. Event type is ' + str(ev_type) + ', event size is ' + str(ev_size))
            if ev_size != 8:
                print('Wrong event size. Aborting.')
                return -1, -1, -1, -1
        else:  # set default ev type and size
            if verbose:
                print('> No header. Setting default event type and size.')
            ev_size = 8
            ev_type = 0

        # Compute number of events in the file (a trailing partial event is ignored).
        start = file.tell()
        file.seek(0, 2)
        stop = file.tell()
        file.seek(start)
        n_events = (stop - start) // ev_size
        if verbose:
            print("> The file contains %d events." % n_events)

        # Decode all events at once instead of unpacking them one by one.
        buffer = file.read(n_events * ev_size)
    raw = np.frombuffer(buffer, dtype=np.uint64) if n_events else np.zeros(0, dtype=np.uint64)

    # Keep only the events whose index lies in the inclusive range events_restriction.
    index = np.arange(n_events)
    raw = raw[(index >= events_restriction[0]) & (index <= events_restriction[1])]

    # Timestamps are masked to 32 unsigned bits, so they can never be negative.
    timestamps = (raw & tsmask).astype(np.int64)
    polarities = ((raw & polmask) >> polpadding).astype(np.int64)
    x = ((raw & xmask) >> xpadding).astype(np.int64)
    y = ((raw & ymask) >> ypadding).astype(np.int64)
    coords = np.stack((x, y), axis=1)
    if verbose:
        print("> After loading events, actually found {0} events.".format(timestamps.size))

    if timestamps.size == 0:
        # Nothing to shift or filter; avoid indexing an empty array below.
        return timestamps, coords, polarities, (0 if drop_negative_dt else -1)

    if orig_at_zero:
        timestamps = timestamps - timestamps[0]

    if drop_negative_dt:
        if verbose:
            print('> Looking for negative dts...')
        # An event is kept iff its timestamp is >= every earlier timestamp. This is the fixed point of
        # repeatedly deleting any event whose timestamp is smaller than its predecessor's.
        keep = timestamps >= np.maximum.accumulate(timestamps)
        removed_events = int(np.count_nonzero(~keep))
        if removed_events > 0:
            timestamps = timestamps[keep]
            polarities = polarities[keep]
            coords = coords[keep]
            if verbose:
                print('> Removed {0} events.'.format(removed_events))
    else:
        removed_events = -1
    if verbose:
        print("> Sequence duration: {0:.2f}s, ts[0] = {1}, ts[{2}] = {3}.".format(float(timestamps[-1] - timestamps[0]) / 1e6, timestamps[0], len(timestamps)-1, timestamps[-1]))

    return timestamps, coords, polarities, removed_events
# ---------------------------------------------------------------------------------------------

from typing import Callable, Dict, Optional, Tuple
from .. import datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import shutil
import time
from .. import configure
from ..datasets import np_savez



class NAVGestureWalk(sjds.NeuromorphicDatasetFolder):
    # 6 gestures: left, right, up, down, home, select.
    # 10 subjects, holding the phone in one hand (selfie mode) while walking indoor and outdoor
    def __init__(
            self,
            root: str,
            data_type: str = 'event',
            frames_number: Optional[int] = None,
            split_by: Optional[str] = None,
            duration: Optional[int] = None,
            custom_integrate_function: Optional[Callable] = None,
            custom_integrated_frames_dir_name: Optional[str] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        """
        The Nav Gesture dataset, which is proposed by `Event-Based Gesture Recognition With Dynamic Background Suppression Using Smartphone Computational Capabilities <https://www.frontiersin.org/articles/10.3389/fnins.2020.00275/full>`_.

        Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about params information.
        """
        # NOTE(review): the second positional argument of the base constructor is passed as None;
        # presumably the train/test flag, as this dataset has no predefined split — confirm.
        super().__init__(root, None, data_type, frames_number, split_by, duration, custom_integrate_function, custom_integrated_frames_dir_name, transform, target_transform)

    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url`` that ``url[i]`` is a tuple, which contains the i-th file's name, download link, and MD5
        :rtype: list
        '''
        return [('navgesture-walk.zip', 'https://www.neuromorphic-vision.com/public/downloads/navgesture/navgesture-walk.zip', '5d305266f13005401959e819abe206f0')]

    @staticmethod
    def downloadable() -> bool:
        '''
        :return: Whether the dataset can be directly downloaded by python codes. If not, the user have to download it manually
        :rtype: bool
        '''
        print('This dataset can not be downloaded now. Refer to https://github.com/fangwei123456/spikingjelly/issues/423 for more details.')
        return False

    @staticmethod
    def extract_downloaded_files(download_root: str, extract_root: str):
        '''
        :param download_root: Root directory path which saves downloaded dataset files
        :type download_root: str
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :return: None

        This function defines how to extract download files.
        The outer archive is extracted to a temporary directory; every ``.zip`` found inside it is then
        extracted in parallel to ``extract_root``, and the temporary directory is removed.
        '''
        temp_ext_dir = os.path.join(download_root, 'temp_ext')
        # exist_ok: a previous interrupted run may have left this directory behind.
        os.makedirs(temp_ext_dir, exist_ok=True)
        print(f'Mkdir [{temp_ext_dir}].')
        extract_archive(os.path.join(download_root, 'navgesture-walk.zip'), temp_ext_dir)
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 4)) as tpe:
            sub_threads = []
            for zip_file in os.listdir(temp_ext_dir):
                if os.path.splitext(zip_file)[1] == '.zip':
                    zip_file = os.path.join(temp_ext_dir, zip_file)
                    print(f'Extract [{zip_file}] to [{extract_root}].')
                    sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))

            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    exit(-1)

        shutil.rmtree(temp_ext_dir)
        print(f'Rmtree [{temp_ext_dir}].')

    @staticmethod
    def get_H_W() -> Tuple:
        '''
        :return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
            For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset.
        :rtype: tuple
        '''
        return 240, 304  # this camera is 240*320, but x.max() = 303. So, I set W = 304.


    @staticmethod
    def load_origin_data(file_name: str) -> Dict:
        '''
        :param file_name: path of a ``_td.dat`` event file
        :type file_name: str
        :return: a dict with keys ``'t'``, ``'x'``, ``'y'``, ``'p'``, each a numpy array with one entry per event
        :rtype: Dict

        Events are read by ``readATIS_tddat`` with its defaults (timestamps start at 0, out-of-order events
        are dropped). The y axis is flipped: ``y = H - 1 - y_sensor``.
        '''
        t, xy, p, _ = readATIS_tddat(file_name, verbose=False)
        if not isinstance(xy, np.ndarray):
            # readATIS_tddat returns -1 for every output when the header declares an unsupported event size.
            raise ValueError(f'Unsupported event size in [{file_name}].')
        H, _ = NAVGestureWalk.get_H_W()
        return {'t': t, 'x': xy[:, 0], 'y': H - 1 - xy[:, 1], 'p': p}

    @staticmethod
    def read_aedat_save_to_np(bin_file: str, np_file: str):
        '''
        :param bin_file: path of the source ``_td.dat`` event file
        :type bin_file: str
        :param np_file: path of the ``npz`` file to write
        :type np_file: str
        :return: None

        Decode ``bin_file`` with :meth:`load_origin_data` and save ``t``, ``x``, ``y``, ``p`` to ``np_file``.
        '''
        events = NAVGestureWalk.load_origin_data(bin_file)
        np_savez(np_file,
                 t=events['t'],
                 x=events['x'],
                 y=events['y'],
                 p=events['p']
                 )
        print(f'Save [{bin_file}] to [{np_file}].')

    @staticmethod
    def create_events_np_files(extract_root: str, events_np_root: str):
        '''
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :param events_np_root: Root directory path which saves events files in the ``npz`` format
        :type events_np_root: str
        :return: None

        This function defines how to convert the origin binary data in ``extract_root`` to ``npz`` format and save converted files in ``events_np_root``.
        ``extract_root`` is expected to contain one directory per user; the label of each file is the second
        ``_``-separated field of its base name.
        '''
        t_ckp = time.time()
        np_dir_dict = {}
        for label in ['le', 'ri', 'up', 'do', 'ho', 'se']:
            np_dir = os.path.join(events_np_root, label)
            os.mkdir(np_dir)
            print(f'Mkdir [{np_dir}].')
            np_dir_dict[label] = np_dir

        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(),
                                                configure.max_threads_number_for_datasets_preprocess)) as tpe:
            # Submit the files of all users before waiting, so conversions of different users overlap.
            sub_threads = []
            for user_name in os.listdir(extract_root):
                aedat_dir = os.path.join(extract_root, user_name)
                for bin_file in os.listdir(aedat_dir):
                    base_name = os.path.splitext(bin_file)[0]
                    label = base_name.split('_')[1]
                    source_file = os.path.join(aedat_dir, bin_file)
                    target_file = os.path.join(np_dir_dict[label], base_name + '.npz')
                    print(f'Start to convert [{source_file}] to [{target_file}].')

                    sub_threads.append(tpe.submit(NAVGestureWalk.read_aedat_save_to_np, source_file,
                                                  target_file))

            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    exit(-1)
        print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')


class NAVGestureSit(NAVGestureWalk):
    """
    The "sit" subset of the Nav Gesture dataset.

    It reuses every :class:`NAVGestureWalk` method except the archive description
    (:meth:`resource_url_md5`) and the extraction step, which reads this subset's archive.
    """

    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url`` that ``url[i]`` is a tuple, which contains the i-th file's name, download link, and MD5
        :rtype: list
        '''
        return [('navgesture-sit.zip', 'https://www.neuromorphic-vision.com/public/downloads/navgesture/navgesture-sit.zip', '1571753ace4d9e0946e6503313712c22')]

    @staticmethod
    def extract_downloaded_files(download_root: str, extract_root: str):
        '''
        :param download_root: Root directory path which saves downloaded dataset files
        :type download_root: str
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :return: None

        This function defines how to extract download files.
        The outer archive is extracted to a temporary directory; every ``.zip`` found inside it is then
        extracted in parallel to ``extract_root``, and the temporary directory is removed.
        '''
        temp_ext_dir = os.path.join(download_root, 'temp_ext')
        # exist_ok: a previous interrupted run may have left this directory behind.
        os.makedirs(temp_ext_dir, exist_ok=True)
        print(f'Mkdir [{temp_ext_dir}].')
        # Take the archive name from resource_url_md5 so the two cannot drift apart.
        archive_name = NAVGestureSit.resource_url_md5()[0][0]
        extract_archive(os.path.join(download_root, archive_name), temp_ext_dir)
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 4)) as tpe:
            sub_threads = []
            for zip_file in os.listdir(temp_ext_dir):
                if os.path.splitext(zip_file)[1] == '.zip':
                    zip_file = os.path.join(temp_ext_dir, zip_file)
                    print(f'Extract [{zip_file}] to [{extract_root}].')
                    sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))

            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    exit(-1)

        shutil.rmtree(temp_ext_dir)
        print(f'Rmtree [{temp_ext_dir}].')
